import os
import time

import pandas as pd
import openai

# Experiment #2
# zero shot with improved prompting

# Credentials are taken from the environment so real secrets never need to be
# written into this file. The "####" placeholders are kept as fallbacks, so the
# script behaves as before when the variables are unset.
openai.organization = os.environ.get("OPENAI_ORGANIZATION", "####")
openai.api_key = os.environ.get("OPENAI_API_KEY", "####")

# Sampling temperature passed to the model on every request.
temperature = 1.0 # choose between 1.0 and 0.2 for replication
# Prompt variant; "v2" adds the system message in isSameEntity.
version = "v2"
# Output CSV name encodes the prompt version and temperature of this run.
output_name = f"experiment_2_{version}_temp{temperature}.csv"

def isSameEntity(entity_a, entity_b, version="v2"):
    """Ask the chat model how confident it is that two strings name one entity.

    Args:
        entity_a: First entity, interpolated verbatim into the prompt.
        entity_b: Second entity, interpolated verbatim into the prompt.
        version: Prompt variant. "v2" prepends a system message describing the
            entities as Congressional candidates; any other value sends only
            the user question.

    Returns:
        The model's reply converted with ``float``.

    Raises:
        ValueError: If the reply text cannot be parsed as a float.
        Exception: Anything raised by ``openai.ChatCompletion.create`` is
            propagated unchanged.

    The sampling temperature is read from the module-level ``temperature``.
    """
    # The user question is identical for every prompt version; only the
    # presence of the system message differs.
    question = {"role": "user", "content": f"How confident are you that the following entities, {entity_a} and {entity_b}, refer to the same entity, allowing for the possibility of minor typos?\nPlease return your confidence in the range of 0 and 1 only and no other words."}
    if version == "v2":
        messages = [
            {"role": "system", "content": "You are a helpful and knowledgeable assistant. You will be given two entities. Both entities refer to Congressional candidates, where 'R' stands for Republican and 'D' stands for 'Democrat'."},
            question,
        ]
    else:
        messages = [question]
    response = openai.ChatCompletion.create(
        model="gpt-4",  # gpt-4-0613
        temperature=temperature,
        messages=messages,
    )
    # The prompt asks for a bare number, so the whole reply is parsed as one.
    return float(response["choices"][0]["message"]["content"])


# Maps "<amicus>---<bonica>" to the score already obtained for that pair, so a
# pair that appears on several rows is only sent to the API once.
cache = {}

# Number of attempts per pair before the whole run is aborted.
max_number_tries = 10


def process(file_name):
    """Score every (amicus, bonica) pair in a CSV and write the result to disk.

    Reads ``file_name`` with pandas, calls ``isSameEntity`` for each row's
    ``amicus`` and ``bonica`` values using the module-level ``version``, stores
    the score in a new ``chatGPT_score`` column, and writes the frame to the
    module-level ``output_name`` without the index.

    Args:
        file_name: Path of the input CSV; it must contain ``amicus`` and
            ``bonica`` columns.

    Raises:
        SystemExit: With status 1 when a single pair fails
            ``max_number_tries`` times in a row. Nothing is written then.
    """
    df = pd.read_csv(file_name)
    df["chatGPT_score"] = None
    for i in range(df.shape[0]):
        row = df.iloc[i]
        a, b = row["amicus"], row["bonica"]
        key = a + '---' + b
        if key in cache:
            response = cache[key]
            print("cached")
        else:
            tries = 0
            while True:
                tries += 1
                try:
                    response = isSameEntity(a, b, version)
                    break
                except Exception:
                    # Exception (not a bare except) so Ctrl-C still stops the run.
                    if tries >= max_number_tries:
                        print("Exceeded "+ str(max_number_tries) + " tries ... ...")
                        # Non-zero status so callers can tell the run failed.
                        raise SystemExit(1)
                    print("index "+ str(i) +" retrying ...", tries)
                    # Back off a little longer after each failed attempt.
                    time.sleep(0.1 * tries)
            # Remember the score; without this the cache lookup never hits.
            cache[key] = response
        # Label-based write; matches the positional i for read_csv's default index.
        df.loc[i, "chatGPT_score"] = response
        print(i)
    df.to_csv(output_name, index=False)

# Script driver: run the scoring pass over the survey CSV and print wall-clock
# timing for it.
# NOTE(review): there is no `if __name__ == "__main__"` guard, so this also
# runs when the module is imported.
start = time.time()  # float seconds since the epoch
print("start", start)
process("Survey_Final_Results.csv")  # relative path, resolved against the current working directory
end = time.time()
print("end", end)
print("total", end - start)  # elapsed seconds for the whole run
